Introduce Mega-C++ to reduce CPU overhead #3099
/te-ci pytorch L1
| m.def("te_general_grouped_gemm_for_discrete_out",
|       &transformer_engine::pytorch::te_general_grouped_gemm_for_discrete_out,
|       "Grouped GEMM for discrete output list");
| m.def("megacpp_grouped_mlp_forward", &transformer_engine::pytorch::megacpp_grouped_mlp_forward,
There was a problem hiding this comment.
We should expose these functions within the tex.grouped_mlp_experimental submodule:
TransformerEngine/transformer_engine/pytorch/csrc/extensions/pybind.cpp
Lines 647 to 650 in 3fffa55
There was a problem hiding this comment.
It would make more sense to organize:
csrc/
├── extensions/
│ ├── grouped_mlp_experimental/
│ │ ├── megacpp.cpp
│ │ └── grouped_mlp_experimental.cpp
│ ├── pybind.cpp
│ └── ...
If we implement more mega-C++ impls in the future, I don't see a reason why they would be more similar to each other than to the block they are fusing.
| name: str
| is_scaled: bool
| is_gated: bool
| glu_interleave_size: int
There was a problem hiding this comment.
Is it worth supporting GLU interleaving in the mega-C++ path? The only benefit is to support the fused GEMM+GLU kernel, and otherwise the unnecessary memory-bound kernel means perf is a lost cause. If we can simplify our optimized code paths, then it's worth it.
There was a problem hiding this comment.
> The only benefit is to support the fused GEMM+GLU kernel
I do hope that in the future we can launch CuTe DSL fused kernels from C++ with some TVM-FFI tricks; otherwise we are forced to choose between better kernel fusions and lower CPU overhead. Currently the CuTe DSL fusion path is heavily CPU-bound for small models, and we rely on CUDA graphs and paged stashing for it to work well.
There was a problem hiding this comment.
Why not make a separate fused op for mega-C++ CuTe DSL? It'll make the implementations less entangled, so there are fewer edge cases or complications that an agent might misunderstand.
There was a problem hiding this comment.
Yes, if there is a CuTe DSL version with mega-C++, it's going to take its own code path with zero code reuse, since it's going to be agent-assisted coding anyway.
| # Explicit env opt-in gives megacpp first chance. Unsupported recipes intentionally
| # return the ops unchanged so lower-priority recipe-specific fusers remain the
| # fallback path.
| register_forward_fusion(fuse_forward_megacpp_ops, prepend=True)
There was a problem hiding this comment.
The GEMM+act fusions provide better GPU perf, so I think they should take higher priority than mega-C++. Basically, I see mega-C++ as "we can't do any better on GPU than the unfused impl, but at least we can make the CPU overhead very small".
There was a problem hiding this comment.
The current order is as follows:
- Check the env var.
- If the env var is 1, check whether the recipe is supported by mega-C++: BF16 is supported; MXFP8 and NVFP4 are not.
- For MXFP8 and NVFP4, mega-C++ falls back and the next fusion is checked.
The reasoning is that I do not want to compromise between better fusion and being less host-bound, so for future MXFP8 support we can do the following two things:
- Integrate CuTe DSL directly via TVM-FFI, with cuBLAS as a backup plan.
- Maybe add a new value, NVTE_MEGACPP_GROUPED_LINEAR=forced, so that users who cannot enable CUDA graphs for some reason can force the C++ path when they know their training is more host-bound.
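The ordering described above can be sketched with a toy registry. This is only an illustration under stated assumptions: the `register_forward_fusion` here is a stand-in with the same name and `prepend` argument as the call in the diff, while the op names, the string recipes, and the `"1"` comparison for the env var are hypothetical and not taken from Transformer Engine.

```python
import os
from typing import Callable, List, Optional

OpList = List[str]
Fuser = Callable[[OpList, Optional[str]], OpList]

_forward_fusers: List[Fuser] = []
GROUPED_MLP = ["fc1", "act", "fc2"]  # hypothetical op pattern for FC1 - act - FC2


def register_forward_fusion(fuser: Fuser, prepend: bool = False) -> None:
    """Toy registry; prepend=True gives the fuser first chance."""
    if prepend:
        _forward_fusers.insert(0, fuser)
    else:
        _forward_fusers.append(fuser)


def fuse_megacpp(ops: OpList, recipe: Optional[str]) -> OpList:
    # Assumption: the flag is treated as enabled only for the literal string "1".
    enabled = os.environ.get("NVTE_MEGACPP_GROUPED_LINEAR", "0") == "1"
    if not enabled or recipe is not None or ops != GROUPED_MLP:
        return ops  # unchanged, so lower-priority fusers can still match
    return ["megacpp_grouped_mlp"]


def fuse_cutedsl(ops: OpList, recipe: Optional[str]) -> OpList:
    # Stand-in for the recipe-specific MXFP8/NVFP4 fusers.
    if recipe in ("mxfp8", "nvfp4") and ops == GROUPED_MLP:
        return [f"cutedsl_grouped_mlp_{recipe}"]
    return ops


def apply_fusions(ops: OpList, recipe: Optional[str]) -> OpList:
    # Each fuser sees the op list left by the previous one.
    for fuser in _forward_fusers:
        ops = fuser(ops, recipe)
    return ops


register_forward_fusion(fuse_cutedsl)
register_forward_fusion(fuse_megacpp, prepend=True)
```

With the flag set, a no-recipe grouped MLP is claimed by the mega-C++ fuser, while an MXFP8 or NVFP4 recipe passes through it untouched and is picked up by the next fuser in line.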
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Summary
This PR introduces "Mega-C++" (
Confidence Score: 4/5
The functional path (forward GEMM → scaled activation → backward GEMM → wgrad) is well-implemented and covered by integration tests. All findings are quality/efficiency concerns rather than correctness failures in the happy path.
The core BF16 grouped MLP pipeline is logically correct and backed by both C++ unit tests and Python integration tests. Issues found:
- a dead-code branch that can confuse future recipe additions;
- an unbounded per-stream scratch cache that can accumulate GPU allocations;
- an overly broad input_requires_grad causing unnecessary dx computation in frozen-input scenarios;
- a delay_wgrad incompatibility caught only at backward time.
None affect correctness for the supported BF16/FP16, no-delay_wgrad configuration.
forward_grouped_mlp_megacpp.py warrants a second look for the recipe predicate, LRU cache bound, and input_requires_grad assignment. backward_grouped_mlp_megacpp.py should move the delay_wgrad guard earlier in the fusion lifecycle.
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Py as Python Autograd
participant Fwd as ForwardGroupedMLP_MegaCpp
participant Bwd as BackwardGroupedMLP_MegaCpp
participant CPP as C++ forward/backward
participant GEMM as cuBLAS Grouped GEMM
participant Act as Scaled Activation Kernels
Py->>Fwd: fuser_forward(input, split_sizes, act_scales)
Fwd->>Fwd: resolve weights / bias / scratch
Fwd->>CPP: tex.megacpp_grouped_mlp_forward
CPP->>GEMM: FC1 grouped GEMM
CPP->>Act: nvte_scaled_swiglu / clamped_swiglu / srelu
CPP->>GEMM: FC2 grouped GEMM
CPP-->>Fwd: output, offsets, fc1_preact, fc2_x
Fwd->>Fwd: save_for_backward
Fwd-->>Py: fc2_out
Py->>Bwd: fuser_backward(grad_output)
Bwd->>CPP: tex.megacpp_grouped_mlp_backward
CPP->>GEMM: FC2 wgrad
CPP->>GEMM: FC2 dgrad
CPP->>Act: nvte_scaled_dswiglu / dsrelu
CPP->>GEMM: FC1 wgrad
CPP->>GEMM: FC1 dgrad
CPP-->>Bwd: grad_input, fc1_dy, grad_act_scales, wgrads
Bwd->>Bwd: compute_grouped_dbias (Triton)
Bwd-->>Py: grad_input, grad_params, grad_extra_inputs
Reviews (1): Last reviewed commit: "integrate fused scaled swiglu and srelu"
| def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool:
|     """Whether megacpp is a valid candidate for the active quantization recipe.
|
|     Today the C++ implementation is BF16/FP16-only, so only the no-recipe path
|     is supported. Returning False for FP8 recipes is intentional: it leaves the
|     op list unchanged so the existing MXFP8/NVFP4 CuTe DSL fusers can match.
|     Future MXFP8/NVFP4 support should be enabled by changing this predicate,
|     not by reordering fusion registrations.
|     """
|     if recipe is None:
|         return True
|     if recipe.mxfp8() or recipe.nvfp4():
|         return False
|     return False
There was a problem hiding this comment.
The `if recipe.mxfp8() or recipe.nvfp4(): return False` branch is dead code: any non-None recipe that is not mxfp8/nvfp4 falls through to the identical final `return False`. The function is equivalent to `return recipe is None`; the current form misleads readers about which cases are being explicitly handled and can confuse future engineers adding support for a new recipe type.
Suggested change:
| def _megacpp_supports_recipe(recipe: Optional[Recipe]) -> bool:
|     """Whether megacpp is a valid candidate for the active quantization recipe.
|
|     Today the C++ implementation is BF16/FP16-only, so only the no-recipe path
|     is supported. Any quantized recipe (MXFP8, NVFP4, or future types) returns
|     False, leaving the op list unchanged so recipe-specific fusers can match.
|     Future quantized-compute support should be enabled by adding an explicit
|     `return True` here for the new recipe type, not by reordering fusions.
|     """
|     return recipe is None
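The equivalence claimed in this comment can be checked with a small stand-alone sketch. `StubRecipe` is a hypothetical stand-in that only mimics the `mxfp8()` and `nvfp4()` methods used in the diff; it is not the real `Recipe` class, and the `"other"` kind is an invented placeholder for any future recipe type.

```python
class StubRecipe:
    """Hypothetical stand-in exposing only the two methods the predicate calls."""

    def __init__(self, kind: str) -> None:
        self.kind = kind

    def mxfp8(self) -> bool:
        return self.kind == "mxfp8"

    def nvfp4(self) -> bool:
        return self.kind == "nvfp4"


def supports_recipe_original(recipe) -> bool:
    # Same control flow as the version under review.
    if recipe is None:
        return True
    if recipe.mxfp8() or recipe.nvfp4():
        return False
    return False


def supports_recipe_simplified(recipe) -> bool:
    # Same control flow as the suggested replacement.
    return recipe is None


CASES = [None, StubRecipe("mxfp8"), StubRecipe("nvfp4"), StubRecipe("other")]


def predicates_agree() -> bool:
    """True when both forms return the same answer for every case."""
    return all(
        supports_recipe_original(r) == supports_recipe_simplified(r) for r in CASES
    )
```

Both forms return True only for `None`, so the middle branch changes nothing about behavior; it only affects how the function reads.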
| @functools.lru_cache(maxsize=None)
| def _cached_grouped_gemm_scratch(
|     num_groups: int,
|     device_index: int,
|     _stream_handle: int,
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
There was a problem hiding this comment.
The `@functools.lru_cache(maxsize=None)` with a CUDA stream handle as part of the key creates an unbounded cache. Each unique (num_groups, device_index, stream_handle) triplet permanently holds three CUDA tensors. In pipeline-parallel configurations or test suites that construct many `torch.cuda.Stream()` objects over time, destroyed streams' allocations are never freed because `lru_cache` holds the only live reference. Bounding the cache (e.g., `maxsize=64`) caps the worst-case retained GPU memory.
Suggested change:
| @functools.lru_cache(maxsize=64)
| def _cached_grouped_gemm_scratch(
|     num_groups: int,
|     device_index: int,
|     _stream_handle: int,
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
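The effect of the bound can be seen without a GPU. The sketch below is an illustration only: a `bytearray` stands in for the scratch tensors, integers stand in for stream handles, and `make_scratch_cache` and `simulate` are hypothetical helpers that do not exist in the PR.

```python
import functools


def make_scratch_cache(maxsize):
    """Build a cached allocator keyed like the function in the diff."""

    @functools.lru_cache(maxsize=maxsize)
    def cached_scratch(num_groups: int, device_index: int, stream_handle: int):
        # Stand-in for allocating CUDA scratch tensors.
        return bytearray(1024)

    return cached_scratch


def simulate(cache_fn, num_streams: int) -> int:
    """Request scratch once per distinct stream handle; return entries retained."""
    for handle in range(num_streams):
        cache_fn(8, 0, handle)
    return cache_fn.cache_info().currsize
```

With `maxsize=None` every distinct handle leaves an entry behind, whereas a bounded cache evicts the least recently used one. The trade-off is that a stream whose entry has been evicted triggers a fresh allocation on its next call.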
| requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
| input_requires_grad = requires_grad
| fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad
| fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad
There was a problem hiding this comment.
`input_requires_grad` is set to the generic `requires_grad` flag, which is True whenever any of the three op contexts requires a gradient, including the weight-only case. When the input tensor is frozen, the C++ backward still computes the full FC1 dgrad GEMM and discards the result. Using `fc1_ctx.input_requires_grad` as an additional gate matches the existing fuser convention and avoids the wasted computation.
Suggested change:
| requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
| input_requires_grad = requires_grad and fc1_ctx.input_requires_grad
| fc1_weight_requires_grad = requires_grad and fc1_weight_param.requires_grad
| fc2_weight_requires_grad = requires_grad and fc2_weight_param.requires_grad
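A minimal sketch of the gating follows, using plain dataclasses instead of the real op contexts. `OpCtx` and `grad_flags` are hypothetical; the sketch assumes the FC1 context is the first entry in `basic_op_ctxs` and that it carries an `input_requires_grad` attribute, as the suggestion implies.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class OpCtx:
    """Hypothetical stand-in for a basic op context."""

    requires_grad: bool
    input_requires_grad: bool = False


def grad_flags(
    basic_op_ctxs: List[OpCtx],
    fc1_weight_requires_grad: bool,
    fc2_weight_requires_grad: bool,
) -> Dict[str, bool]:
    requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs)
    fc1_ctx = basic_op_ctxs[0]  # assumption: FC1 comes first in the list
    return {
        # Gate dgrad on the input itself, not on "anything needs a gradient".
        "input": requires_grad and fc1_ctx.input_requires_grad,
        "fc1_weight": requires_grad and fc1_weight_requires_grad,
        "fc2_weight": requires_grad and fc2_weight_requires_grad,
    }
```

In the frozen-input, trainable-weights case the `input` flag comes out False, so a backward implementation honoring it could skip the FC1 dgrad GEMM.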
| if _delay_wgrad(fc_op, ctx):
|     raise ValueError("megacpp grouped MLP does not support delay_wgrad_compute=True.")
There was a problem hiding this comment.
The `delay_wgrad_compute=True` incompatibility is detected inside `fuser_backward`, which runs during `.backward()`, after the forward pass has already executed. A user who constructs a megacpp-fused model with `delay_wgrad_compute=True` will get through an entire forward step before hitting the ValueError. Moving this check to `ForwardGroupedMLP_MegaCpp.__init__` or `fuse_forward_megacpp_ops` would surface the error at model-construction or fusion time instead.
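One way the earlier check might look is sketched below. All names are hypothetical stand-ins: `FakeGroupedLinear` only carries a `delay_wgrad_compute` attribute, and `ForwardGroupedMLPSketch` is not the real `ForwardGroupedMLP_MegaCpp`. How the real ops expose this setting is an assumption here.

```python
class FakeGroupedLinear:
    """Hypothetical op that only records the delay_wgrad_compute setting."""

    def __init__(self, delay_wgrad_compute: bool = False) -> None:
        self.delay_wgrad_compute = delay_wgrad_compute


def check_megacpp_supported(fc_ops) -> None:
    """Raise before any forward work if an op uses an unsupported setting."""
    for op in fc_ops:
        if getattr(op, "delay_wgrad_compute", False):
            raise ValueError(
                "megacpp grouped MLP does not support delay_wgrad_compute=True."
            )


class ForwardGroupedMLPSketch:
    """Toy fused op that validates its inputs at construction time."""

    def __init__(self, fc1: FakeGroupedLinear, fc2: FakeGroupedLinear) -> None:
        check_megacpp_supported((fc1, fc2))
        self.fc1 = fc1
        self.fc2 = fc2
```

An alternative design would have the fuser return the ops unchanged instead of raising, so an unsupported configuration silently falls back to the unfused path. Which behavior is preferable is a judgment call for the PR authors.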
/te-ci pytorch L1
Description
Assistant: GPT5.5 codex
Issue: #2897
Gets rid of CPU overhead whenever CUDA Graphs are not applicable. Guarded by NVTE_MEGACPP_GROUPED_LINEAR.
Drop-in replacement for the grouped MLP, i.e. FC1 - act - FC2. Targets BF16 grouped GEMM with the cuBLAS grouped GEMM backend.
In the future, we can extend to MXFP8 / NVFP4 with the cuBLAS backend, or even CuTe DSL grouped GEMM, and call cute.jit in C++: NVIDIA/cutlass#3289
Recommend CUDA >= 13.2.1
Dependency of merge: #3132 (this PR is rebased on top of that branch).
TODO: